{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zb9d5N92iMz2"
      },
      "source": [
        "# Chroma quickstart\n",
        "\n",
        "First, run the [setup cell](#setup) below. Then, run [this cell](#unconditional-chain) to get a Chroma sample. Further examples are below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KaT878cbeOQm"
      },
      "outputs": [],
      "source": [
        "# @title Setup\n",
        "\n",
        "# @markdown [Get your API key here](https://chroma-weights.generatebiomedicines.com) and enter it below before running.\n",
        "\n",
        "from google.colab import output\n",
        "\n",
        "output.enable_custom_widget_manager()\n",
        "\n",
        "import os\n",
        "\n",
        "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
        "import contextlib\n",
        "\n",
        "api_key = \"\"  # @param {type:\"string\"}\n",
        "\n",
        "!pip install generate-chroma > /dev/null 2>&1\n",
        "\n",
        "import torch\n",
        "\n",
        "torch.use_deterministic_algorithms(True, warn_only=True)\n",
        "\n",
        "import warnings\n",
        "from tqdm import tqdm, TqdmExperimentalWarning\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=TqdmExperimentalWarning)\n",
        "from functools import partialmethod\n",
        "\n",
        "tqdm.__init__ = partialmethod(tqdm.__init__, leave=False)\n",
        "\n",
        "from google.colab import files\n",
        "import ipywidgets as widgets\n",
        "\n",
        "\n",
        "def create_button(filename, description=\"\"):\n",
        "    button = widgets.Button(description=description)\n",
        "    display(button)\n",
        "\n",
        "    def on_button_click(b):\n",
        "        files.download(filename)\n",
        "\n",
        "    button.on_click(on_button_click)\n",
        "\n",
        "\n",
        "def render(protein, trajectories, output=\"protein.cif\"):\n",
        "    display(protein)\n",
        "    print(protein)\n",
        "    protein.to_CIF(output)\n",
        "    traj_output = output.replace(\".cif\", \"_trajectory.cif\")\n",
        "    trajectories[\"trajectory\"].to_CIF(traj_output)\n",
        "    create_button(output, description=\"Download sample\")\n",
        "    create_button(traj_output, description=\"Download trajectory\")\n",
        "\n",
        "\n",
        "import locale\n",
        "\n",
        "locale.getpreferredencoding = lambda: \"UTF-8\"\n",
        "\n",
        "from chroma import Chroma, Protein, conditioners\n",
        "from chroma.models import graph_classifier, procap\n",
        "from chroma.utility.api import register_key\n",
        "from chroma.utility.chroma import letter_to_point_cloud, plane_split_protein\n",
        "\n",
        "register_key(api_key)\n",
        "with contextlib.redirect_stdout(None):\n",
        "    chroma = Chroma()\n",
        "\n",
        "device = \"cuda\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sfNODIk5BZEH"
      },
      "outputs": [],
      "source": [
        "# @title Get a protein! <a name=\"unconditional-chain\"></a> {display-mode: \"form\"}\n",
        "\n",
        "# @markdown Specify the desired length. Chroma will output a fully designed single chain protein.\n",
        "# @markdown As with all examples in this notebook, the trajectory can also be downloaded.\n",
        "\n",
        "length = 160  # @param {type:\"slider\", min:50, max:250, step:10}\n",
        "\n",
        "protein, trajectories = chroma.sample(\n",
        "    chain_lengths=[length], steps=200, full_output=True\n",
        ")\n",
        "render(protein, trajectories, output=\"protein.cif\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "14AUlQHIxPle"
      },
      "source": [
        "## Conditional generation\n",
        "\n",
        "After running the setup at the top of the notebook, all examples are completely independent.\n",
        "\n",
        "[Single chain](#unconditional-chain): the simplest example of protein generation with Chroma.\n",
        "\n",
        "[Complex](#unconditional-complex): a protein with multiple chains.\n",
        "\n",
        "[Symmetry](#symmetry): a symmetric complex, where the symmetry group and subunit size can be input by the user.\n",
        "\n",
        "[Substructure](#substructure): infilling a PDB structure, where the residues to design can be specified by a PyMOL-style string.\n",
        "\n",
        "[Shape](#shape): Chroma generation conditioned on shape, using letters as an example.\n",
        "\n",
        "[Topology](#proclass-chain): chain-level conditioning using ProClass, where a CAT code can be specified.\n",
        "\n",
        "[Secondary structure](#proclass-residue): ProClass also provides conditioning of secondary structure, which can be input as a per-residue string.\n",
        "\n",
        "[Natural language](#procap): ProCap takes a user caption in order to condition Chroma generation.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KLbZia40CQhK"
      },
      "outputs": [],
      "source": [
        "# @title Complexes <a name=\"unconditional-complex\"></a> {display-mode: \"form\"}\n",
        "\n",
        "# @markdown Given the lengths of individual chains, Chroma can generate a complex.\n",
        "\n",
        "chain1_length = 400  # @param {type:\"slider\", min:100, max:500, step:10}\n",
        "chain2_length = 100  # @param {type:\"slider\", min:0, max:200, step:10}\n",
        "chain3_length = 100  # @param {type:\"slider\", min:0, max:200, step:10}\n",
        "chain4_length = 100  # @param {type:\"slider\", min:0, max:200, step:10}\n",
        "\n",
        "protein, trajectories = chroma.sample(\n",
        "    chain_lengths=[chain1_length, chain2_length, chain3_length, chain4_length],\n",
        "    steps=200,\n",
        "    full_output=True,\n",
        ")\n",
        "render(protein, trajectories, output=\"complex.cif\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7FpA3VzWWTM_"
      },
      "outputs": [],
      "source": [
        "# @title Symmetry <a name=\"symmetry\"></a> {display-mode: \"form\"}\n",
        "\n",
        "# @markdown Specify the desired symmetry type and the size of a single subunit.\n",
        "\n",
        "symmetry_group = \"C_7\"  # @param [\"C_2\", \"C_3\", \"C_4\", \"C_5\", \"C_6\", \"C_7\", \"C_8\", \"D_2\", \"D_3\", \"D_4\", \"D_5\", \"D_6\", \"D_7\", \"D_8\", \"T\", \"O\", \"I\"]\n",
        "subunit_size = 100  # @param {type:\"slider\", min:10, max:150, step:5}\n",
        "knbr = 2\n",
        "\n",
        "conditioner = conditioners.SymmetryConditioner(\n",
        "    G=symmetry_group, num_chain_neighbors=knbr\n",
        ")\n",
        "symmetric_protein, trajectories = chroma.sample(\n",
        "    chain_lengths=[subunit_size],\n",
        "    conditioner=conditioner,\n",
        "    langevin_factor=8,\n",
        "    inverse_temperature=8,\n",
        "    sde_func=\"langevin\",\n",
        "    potts_symmetry_order=conditioner.potts_symmetry_order,\n",
        "    full_output=True,\n",
        ")\n",
        "render(symmetric_protein, trajectories, output=\"symmetric_protein.cif\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BCTxghg19meL"
      },
      "outputs": [],
      "source": [
        "# @title Substructure <a name=\"substructure\"></a> {display-mode: \"form\"}\n",
        "\n",
        "# @markdown Enter a PDB ID and a selection string corresponding to designable positions.\n",
        "# @markdown Using a substructure conditioner, Chroma can design at these positions while holding the rest of the structure fixed.\n",
        "# @markdown The default selection cuts the protein in half and fills it in.\n",
        "# @markdown Other selections, by position or proximity, are also allowed.\n",
        "\n",
        "pdb_id = \"5SV5\"  # @param ['5SV5', '6QAZ', '3BDI'] {allow-input:true}\n",
        "\n",
        "try:\n",
        "    protein = Protein.from_PDBID(pdb_id, canonicalize=True, device=device)\n",
        "except FileNotFoundError:\n",
        "    print(\"Invalid PDB ID! Using 3BDI\")\n",
        "    pdb_id = \"3BDI\"\n",
        "    protein = Protein.from_PDBID(pdb_id, canonicalize=True, device=device)\n",
        "\n",
        "X, C, _ = protein.to_XCS()\n",
        "selection_string = \"namesel infilling_selection\"  # @param ['namesel infilling_selection', 'z > 16', '(resid 50) around 10'] {allow-input:true}\n",
        "residues_to_design = plane_split_protein(X, C, protein, 0.5).nonzero()[:, 1].tolist()\n",
        "protein.sys.save_selection(gti=residues_to_design, selname=\"infilling_selection\")\n",
        "\n",
        "try:\n",
        "    conditioner = conditioners.SubstructureConditioner(\n",
        "        protein, backbone_model=chroma.backbone_network, selection=selection_string\n",
        "    ).to(device)\n",
        "except Exception:\n",
        "    print(\"Error initializing conditioner! Falling back to masking 50% of residues.\")\n",
        "    selection_string = \"namesel infilling_selection\"\n",
        "    conditioner = conditioners.SubstructureConditioner(\n",
        "        protein,\n",
        "        backbone_model=chroma.backbone_network,\n",
        "        selection=selection_string,\n",
        "        rg=True,\n",
        "    ).to(device)\n",
        "\n",
        "infilled_protein, trajectories = chroma.sample(\n",
        "    protein_init=protein,\n",
        "    conditioner=conditioner,\n",
        "    langevin_factor=4.0,\n",
        "    langevin_isothermal=True,\n",
        "    inverse_temperature=8.0,\n",
        "    steps=500,\n",
        "    sde_func=\"langevin\",\n",
        "    full_output=True,\n",
        ")\n",
        "render(infilled_protein, trajectories, output=\"infilled_protein.cif\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h5C-SrJEBqIs"
      },
      "outputs": [],
      "source": [
        "# @title Shape <a name=\"shape\"></a> {display-mode: \"form\"}\n",
        "\n",
        "# @markdown Create a protein in the shape of a desired character of arbitrary length.\n",
        "\n",
        "character = \"G\"  # @param {type:\"string\"}\n",
        "if len(character) > 1:\n",
        "    character = character[:1]\n",
        "    print(f\"Keeping only first character ({character})!\")\n",
        "length = 1000  # @param {type:\"slider\", min:100, max:1500, step:100}\n",
        "\n",
        "letter_point_cloud = letter_to_point_cloud(character)\n",
        "conditioner = conditioners.ShapeConditioner(\n",
        "    letter_point_cloud,\n",
        "    chroma.backbone_network.noise_schedule,\n",
        "    autoscale_num_residues=length,\n",
        ").to(device)\n",
        "\n",
        "shaped_protein, trajectories = chroma.sample(\n",
        "    chain_lengths=[length], conditioner=conditioner, full_output=True\n",
        ")\n",
        "\n",
        "render(shaped_protein, trajectories, output=\"shaped_protein.cif\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M-5X_saooA6J"
      },
      "outputs": [],
      "source": [
        "# @title Fold <a name=\"proclass-chain\"></a> {display-mode: \"form\"}\n",
        "\n",
        "# @markdown Input a [CATH number](https://cathdb.info/browse) to get chain-level conditioning, e.g. `3.40.50` for a Rossmann fold or `2` for mainly beta.\n",
        "\n",
        "CATH = \"3.40.50\"  # @param {type:\"string\"}\n",
        "length = 130  # @param {type:\"slider\", min:50, max:250, step:10}\n",
        "\n",
        "proclass_model = graph_classifier.load_model(\"named:public\", device=device)\n",
        "conditioner = conditioners.ProClassConditioner(\"cath\", CATH, model=proclass_model)\n",
        "cath_conditioned_protein, trajectories = chroma.sample(\n",
        "    conditioner=conditioner, chain_lengths=[length], full_output=True\n",
        ")\n",
        "render(cath_conditioned_protein, trajectories, output=\"cath_conditioned_protein.cif\")"
      ]
    },
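    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A CATH number such as the one entered in the Fold example is a dot-separated hierarchy: class, architecture, topology, homologous superfamily. A shorter code such as `2` or `3.40.50` names a whole branch of that hierarchy. The sketch below is a hypothetical helper that is not part of Chroma: `split_cath_code` and `CATH_LEVELS` are assumed names, and the function only checks that a code has one to four integer fields. It does not check whether ProClass recognizes the code.\n",
        "\n",
        "```python\n",
        "# Hypothetical helper (not a Chroma API): label the fields of a CATH code.\n",
        "CATH_LEVELS = (\"class\", \"architecture\", \"topology\", \"homologous_superfamily\")\n",
        "\n",
        "\n",
        "def split_cath_code(code):\n",
        "    \"\"\"Map each field of a dot-separated CATH code to its level name.\"\"\"\n",
        "    fields = code.strip().split(\".\")\n",
        "    if not 1 <= len(fields) <= len(CATH_LEVELS):\n",
        "        raise ValueError(f\"expected 1 to 4 fields, got {len(fields)}: {code!r}\")\n",
        "    if not all(field.isdigit() for field in fields):\n",
        "        raise ValueError(f\"fields must be integers: {code!r}\")\n",
        "    return dict(zip(CATH_LEVELS, (int(field) for field in fields)))\n",
        "\n",
        "\n",
        "print(split_cath_code(\"3.40.50\"))\n",
        "```\n",
        "\n",
        "Running a check like this before building the conditioner may catch typos such as `3,40,50` early."
      ]
    },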
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A2012-PnoTHf"
      },
      "outputs": [],
      "source": [
        "# @title Secondary structure <a name=\"proclass-residue\"></a> {display-mode: \"form\"}\n",
        "\n",
        "# @markdown Enter a string to specify residue-level secondary structure conditioning: H = helix, E = strand, T = turn.\n",
        "\n",
        "SS = \"HHHHHHHTTTHHHHHHHTTTEEEEEETTTEEEEEEEETTTTHHHHHHHH\"  # @param {type:\"string\"}\n",
        "\n",
        "proclass_model = graph_classifier.load_model(\"named:public\", device=device)\n",
        "conditioner = conditioners.ProClassConditioner(\n",
        "    \"secondary_structure\", SS, max_norm=None, model=proclass_model\n",
        ")\n",
        "ss_conditioned_protein, trajectories = chroma.sample(\n",
        "    steps=500, conditioner=conditioner, chain_lengths=[len(SS)], full_output=True\n",
        ")\n",
        "render(ss_conditioned_protein, trajectories, output=\"ss_conditioned_protein.cif\")"
      ]
    },
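    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In the secondary structure example the chain length is taken from `len(SS)`, so a typo in the string changes both the conditioning and the size of the protein. The following is a minimal sketch, independent of Chroma, that rejects characters outside the H/E/T alphabet listed in that cell and summarizes each segment. `summarize_ss` and `SS_LABELS` are hypothetical names, and whether Chroma itself accepts other characters is not checked here.\n",
        "\n",
        "```python\n",
        "from itertools import groupby\n",
        "\n",
        "# Alphabet taken from the example cell: H = helix, E = strand, T = turn.\n",
        "SS_LABELS = {\"H\": \"helix\", \"E\": \"strand\", \"T\": \"turn\"}\n",
        "\n",
        "\n",
        "def summarize_ss(ss):\n",
        "    \"\"\"Return (label, run_length) pairs for a per-residue string.\"\"\"\n",
        "    unknown = sorted(set(ss) - set(SS_LABELS))\n",
        "    if unknown:\n",
        "        raise ValueError(f\"unsupported characters: {unknown}\")\n",
        "    # groupby yields one group per run of identical consecutive characters.\n",
        "    return [(SS_LABELS[char], len(list(run))) for char, run in groupby(ss)]\n",
        "\n",
        "\n",
        "print(summarize_ss(\"HHHHTTEEEE\"))\n",
        "```\n",
        "\n",
        "The run lengths sum to the length of the input string, which is the value the example passes as `chain_lengths`."
      ]
    },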
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ix41mhyEbLTF"
      },
      "outputs": [],
      "source": [
        "# @title Natural text <a name=\"procap\"></a> {display-mode: \"form\"}\n",
        "\n",
        "# @markdown ProCap uses natural language captions to condition samples.\n",
        "\n",
        "length = 110  # @param {type:\"slider\", min:50, max:250, step:10}\n",
        "caption = \"Crystal structure of SH2 domain\"  # @param {type:\"string\"}\n",
        "\n",
        "procap_model = procap.load_model(\"named:public\", device=device)\n",
        "conditioner = conditioners.ProCapConditioner(caption, -1, model=procap_model)\n",
        "caption_conditioned_protein, trajectories = chroma.sample(\n",
        "    steps=200, chain_lengths=[length], conditioner=conditioner, full_output=True\n",
        ")\n",
        "render(\n",
        "    caption_conditioned_protein, trajectories, output=\"caption_conditioned_protein.cif\"\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
